【MediaPipe】HelloWorldのプログラム動作/処理を解析してみた
カフェチームの山本です。
前回はHelloWorldのプログラムを参考に、MediaPipeのフレームワークとしての動作/構成を学びました。
今回は、HelloWorldのプログラムの内部を詳細に見ながら、グラフに対してデータを入出力する方法 と ソースコードがどのように実行されているかを学びます。
今回も学習しただけであるため、特に新しい知見や結論はありませんが、ご参考になれば幸いです。
MediaPipeに関連する記事はこちらにまとめてあります。
HelloWorldの入出力
前回の記事の「HelloWorldのプログラム」で記載したように、HelloWorldのプログラムは、mediapipe/examples/desktop/hello_world/hello_world.cc から実行されています。中を見てみると、入出力に関して以下のことがわかります。
- 出力準備:初期化したGraphに対して、出力を受け取る用の OutputStreamPoller をくっつける。この際、Graph中の出力ストリームである "out" にアタッチする。(43~44行目)
- 入力:処理を開始したGraphに対して、入力ストリームである "in" を指定して、データを入力する。この際、データは Packet という型にラップし、タイムスタンプを付与する。(48~49行目)
- 出力:ポーラを介して、グラフ中のデータがなくなるまで出力を受け取る。この際、データの型は Packet で受け取るため、データを変換する。(55~57行目)
#include "mediapipe/framework/calculator_graph.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status.h" namespace mediapipe { ::mediapipe::Status PrintHelloWorld() { // Configures a simple graph, which concatenates 2 PassThroughCalculators. CalculatorGraphConfig config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"( input_stream: "in" output_stream: "out" node { calculator: "PassThroughCalculator" input_stream: "in" output_stream: "out1" } node { calculator: "PassThroughCalculator" input_stream: "out1" output_stream: "out" } )"); CalculatorGraph graph; MP_RETURN_IF_ERROR(graph.Initialize(config)); ASSIGN_OR_RETURN(OutputStreamPoller poller, graph.AddOutputStreamPoller("out")); MP_RETURN_IF_ERROR(graph.StartRun({})); // Give 10 input packets that contains the same std::string "Hello World!". for (int i = 0; i < 10; ++i) { MP_RETURN_IF_ERROR(graph.AddPacketToInputStream( "in", MakePacket<std::string>("Hello World!").At(Timestamp(i)))); } // Close the input stream "in". MP_RETURN_IF_ERROR(graph.CloseInputStream("in")); mediapipe::Packet packet; // Get the output packets std::string. while (poller.Next(&packet)) { LOG(INFO) << packet.Get<std::string>(); } return graph.WaitUntilDone(); } } // namespace mediapipe
Graph内のプログラム
Graph内のPassThroughCalculatorですが、mediapipe/examples/desktop/hello_world/BUILDを見ると、mediapipe/calculators/core内のpass_through_calculatorが実行されている(いそうな)ことがわかります。
cc_binary( name = "hello_world", srcs = ["hello_world.cc"], visibility = ["//visibility:public"], deps = [ "//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/framework:calculator_graph", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:status", ], )
mediapipe/calculators/core/BUILDを見ると、pass_through_calculatorは同じファイルのpass_through_calculator.ccを参照していることがわかります。
cc_library( name = "pass_through_calculator", srcs = ["pass_through_calculator.cc"], visibility = [ "//visibility:public", ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:status", ], alwayslink = 1, )
mediapipe/calculators/core/pass_through_calculator.ccを見ると、以下のことがわかります。
- PassThroughCalculatorが定義されており、REGISTER_CALCULATORで登録され、hello_world.ccのGraphの定義で使用されている。(29, 96行目)
- PassThroughCalculatorはCalculatorBaseを継承しており、静的関数GetCongractと、Open・Processがオーバライドされている。Closeは定義されていない。(31, 62, 79行目)
- ProcessはCalculatorContextという型の入力を受け取り、cc->Inputs().Get(id).Value()のようにして入力を受け取り、cc->Outputs().Get(id).AddPacket()のようにして出力する(90行目)
- この id には 0が入ります("in" や "out1" のような入力/出力の数に対応します)。cc->Outputs().Get(id).Name()で"in"や"out1"が得られます。
class PassThroughCalculator : public CalculatorBase { public: static ::mediapipe::Status GetContract(CalculatorContract* cc) { if (!cc->Inputs().TagMap()->SameAs(*cc->Outputs().TagMap())) { return ::mediapipe::InvalidArgumentError( "Input and output streams to PassThroughCalculator must use " "matching tags and indexes."); } for (CollectionItemId id = cc->Inputs().BeginId(); id < cc->Inputs().EndId(); ++id) { cc->Inputs().Get(id).SetAny(); cc->Outputs().Get(id).SetSameAs(&cc->Inputs().Get(id)); } for (CollectionItemId id = cc->InputSidePackets().BeginId(); id < cc->InputSidePackets().EndId(); ++id) { cc->InputSidePackets().Get(id).SetAny(); } if (cc->OutputSidePackets().NumEntries() != 0) { if (!cc->InputSidePackets().TagMap()->SameAs( *cc->OutputSidePackets().TagMap())) { return ::mediapipe::InvalidArgumentError( "Input and output side packets to PassThroughCalculator must use " "matching tags and indexes."); } for (CollectionItemId id = cc->InputSidePackets().BeginId(); id < cc->InputSidePackets().EndId(); ++id) { cc->OutputSidePackets().Get(id).SetSameAs( &cc->InputSidePackets().Get(id)); } } return ::mediapipe::OkStatus(); } ::mediapipe::Status Open(CalculatorContext* cc) final { for (CollectionItemId id = cc->Inputs().BeginId(); id < cc->Inputs().EndId(); ++id) { if (!cc->Inputs().Get(id).Header().IsEmpty()) { cc->Outputs().Get(id).SetHeader(cc->Inputs().Get(id).Header()); } } if (cc->OutputSidePackets().NumEntries() != 0) { for (CollectionItemId id = cc->InputSidePackets().BeginId(); id < cc->InputSidePackets().EndId(); ++id) { cc->OutputSidePackets().Get(id).Set(cc->InputSidePackets().Get(id)); } } cc->SetOffset(TimestampDiff(0)); return ::mediapipe::OkStatus(); } ::mediapipe::Status Process(CalculatorContext* cc) final { cc->GetCounter("PassThrough")->Increment(); if (cc->Inputs().NumEntries() == 0) { return tool::StatusStop(); } for (CollectionItemId id = cc->Inputs().BeginId(); id < cc->Inputs().EndId(); ++id) { if (!cc->Inputs().Get(id).IsEmpty()) { VLOG(3) << "Passing " << cc->Inputs().Get(id).Name() << " to " << cc->Outputs().Get(id).Name() << " at " << cc->InputTimestamp().DebugString(); cc->Outputs().Get(id).AddPacket(cc->Inputs().Get(id).Value()); } } return ::mediapipe::OkStatus(); } }; REGISTER_CALCULATOR(PassThroughCalculator);
まとめ
今回は、HelloWorldのプログラムがどのように動いているかを調べるため、ソースコードを追って見てみました。
次回は、Multi Hand Trackingのプログラムを見ていきます。